import os
import sys
import copy
import argparse
import heapq
from time import time
from tqdm import tqdm

import torch

from utils import EarlyStopping
from ogb.nodeproppred import Evaluator
from torch_geometric.utils import to_undirected

from utils import *

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# Hardware: either 'cpu' or the index of a CUDA device.
parser.add_argument(
    '--device',
    type=str,
    default='0',
    choices=['cpu', '0', '1', '2', '3'],
)

# Data: which graph to load and which node-feature variant to train on.
parser.add_argument(
    '--dataset',
    type=str,
    default='cora',
    choices=['cora', 'pubmed', 'ogbn-arxiv', 'ogbn-products'],
)
parser.add_argument(
    '--feature',
    type=str,
    default="gpt_only_embedding",
    choices=[
        'raw',
        'content_embedding',
        'summary_embedding',
        'gpt_response_embedding',
        'gpt_only_embedding',
    ],
)

# Optimisation schedule.
parser.add_argument('--epochs', type=int, default=10000)
parser.add_argument('--lr', type=float, default=5e-5)
# Forwarded as the `patience` of EarlyStopping further down in this script;
# a value <= 0 means no EarlyStopping object is created.
parser.add_argument('--early_stop', type=int, default=150)

# Model architecture and regularisation.
parser.add_argument(
    '--model',
    type=str,
    default='SAGE',
    choices=['MLP', 'GCN', 'SAGE'],
)
parser.add_argument('--hidden_dim', type=int, default=256)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0)
parser.add_argument('--wd', type=float, default=5e-2)

# Non-zero wraps the epoch loop in a tqdm progress bar.
parser.add_argument('--verbose', type=int, default=1)

args = parser.parse_args()

# Turn the CLI value into a torch device string; without CUDA every choice
# collapses to 'cpu'.
if not torch.cuda.is_available():
    args.device = 'cpu'
elif args.device != 'cpu':
    args.device = 'cuda:' + args.device

def get_split_idx(dataset_folder, dataset, seed=0):
    """Read the train/valid/test node indices of one split file.

    The file ``{dataset_folder}{dataset}_{seed}.txt`` is opened; the folder is
    prepended verbatim, so it must already end with a path separator.  After
    stripping surrounding whitespace the file must consist of exactly three
    lines (train, valid, test, in that order), each holding integer node ids
    separated by arbitrary whitespace.  A line without ids yields an empty
    list.

    Args:
        dataset_folder: Prefix of the split file path.
        dataset: Dataset name used in the file name.
        seed: Split seed used in the file name.

    Returns:
        Three lists of ints: ``(train, valid, test)``.

    Raises:
        ValueError: If the file does not contain exactly three lines or a
            token is not an integer.
        OSError: If the file cannot be opened.
    """
    path = f'{dataset_folder}{dataset}_{seed}.txt'
    with open(path, 'r') as fin:
        lines = fin.read().strip().split('\n')

    if len(lines) != 3:
        raise ValueError(
            f'{path}: expected 3 lines (train/valid/test), found {len(lines)}')

    # str.split() without a separator collapses runs of spaces/tabs and drops
    # leading/trailing whitespace, so hand-edited files do not produce empty
    # tokens that int() would reject.
    train, valid, test = ([int(tok) for tok in line.split()] for line in lines)
    return train, valid, test

def train(model, optimizer, data, loss_func, eval_func):
    """Run a single full-batch optimisation step and return the loss.

    The model is called on ``data.x`` and ``data.edge_index``; the loss is
    computed only on the rows selected by ``data.train_mask``.

    Args:
        model: Callable module invoked as ``model(x, edge_index)``.
        optimizer: Optimizer holding the model's parameters.
        data: Object exposing ``x``, ``edge_index``, ``y`` and ``train_mask``.
        loss_func: Callable ``(logits, targets) -> scalar tensor``.
        eval_func: Not used here; kept so the call signature stays the same.

    Returns:
        The training loss of this step as a Python number.
    """
    model.train()
    optimizer.zero_grad()

    mask = data.train_mask
    out = model(data.x, data.edge_index)
    objective = loss_func(out[mask], data.y[mask])

    objective.backward()
    optimizer.step()
    return objective.item()

@torch.no_grad()
def evaluate(model, data, eval_func):
    """Score the model on the train, validation and test masks.

    The model is switched to eval mode and run once on the whole graph;
    ``eval_func(logits, labels)`` is then applied to each masked subset and
    its result converted with ``.item()``.

    Returns:
        ``(train_acc, val_acc, test_acc, logits)`` where ``logits`` is the
        model output for all nodes.
    """
    model.eval()
    out = model(data.x, data.edge_index)

    # Same order as the returned tuple: train, validation, test.
    masks = (data.train_mask, data.val_mask, data.test_mask)
    scores = [eval_func(out[m], data.y[m]) for m in masks]

    return tuple(score.item() for score in scores) + (out,)

@torch.no_grad()
def evaluate_ogb(model, data, eval_func):
    """Score the model on all three masks through an evaluator object.

    Unlike ``evaluate``, ``eval_func`` here is an object whose ``eval`` method
    takes a dict with ``'y_pred'`` and ``'y_true'`` and returns a dict that
    contains ``'acc'``.  Predictions are the arg-max class kept as a trailing
    dimension of size one; labels get the same trailing dimension.

    Returns:
        ``(train_acc, val_acc, test_acc, logits)`` where ``logits`` is the
        model output for all nodes.
    """
    model.eval()
    out = model(data.x, data.edge_index)
    predicted = out.argmax(dim=-1, keepdim=True)

    def _split_acc(mask):
        # Build the evaluator payload for the nodes selected by `mask`.
        payload = {
            'y_pred': predicted[mask],
            'y_true': data.y[mask].unsqueeze(-1),
        }
        return eval_func.eval(payload)['acc']

    return (
        _split_acc(data.train_mask),
        _split_acc(data.val_mask),
        _split_acc(data.test_mask),
        out,
    )

# Bind the requested architecture to the common alias `GNN`, so the rest of
# the script can construct the model without caring which one was chosen.
if args.model == "MLP":
    from GNNs.MLP.model import MLP as GNN
elif args.model == "SAGE":
    from GNNs.SAGE.model import SAGE as GNN
elif args.model == "GCN":
    from GNNs.GCN.model import GCN as GNN
else:
    # Any other value terminates the script with this message.
    exit(f"Model {args.model} is not supported!")

# Training objective passed to train() as `loss_func`.
loss_func = torch.nn.CrossEntropyLoss()

def eval_func(output, labels): # ACC
    """Return the fraction of rows whose arg-max class equals the label.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        labels: 1-D tensor of target class indices.

    Returns:
        A 0-dim double tensor: number of matches divided by ``len(labels)``.
    """
    # Tensor.max over dim 1 yields (values, indices); only the indices matter.
    _, predicted = output.max(dim=1)
    hits = predicted.type_as(labels).eq(labels)
    return hits.double().sum() / len(labels)

# Where the raw dataset lives and where its per-seed split files are stored.
dataset_folder = '../raw_data/'
split_folder = f'../raw_data/{args.dataset}/splits/'
dataset2num_classes = {'cora': 7, 'pubmed': 3, 'ogbn-arxiv': 40, 'ogbn-products': 47}
from raw_data_utils.load import load_data
data, _, label2text = load_data(dataset=args.dataset, dataset_folder=dataset_folder)

data.edge_index = to_undirected(data.edge_index, data.num_nodes)
num_classes = dataset2num_classes[args.dataset]
num_nodes = data.y.shape[0]

# Any feature choice other than 'raw' replaces data.x with a precomputed
# feature matrix loaded from disk.
if args.feature != 'raw':
    feature_file = f"../processed_data/{args.dataset}/{args.dataset}_{args.feature}_list.pt"
    new_x = torch.load(feature_file)
    # The replacement must provide exactly one row per existing node.
    assert new_x.shape[0] == data.x.shape[0]
    # torch.from_numpy only accepts NumPy arrays, so the stored object is
    # expected to be one.
    data.x = torch.from_numpy(new_x)

# Labels are used as a flat vector by the train/evaluate helpers.
data.y = data.y.flatten()
data = data.to(args.device)

# Echo the run configuration, one setting per line.
print(
    f"Device: {args.device}",
    f"Dataset: {args.dataset}",
    f"Num of nodes: {num_nodes}",
    f"Num of classes: {num_classes}",
    f"Model: {args.model}",
    f"Feature: {args.feature}",
    f"Num layers: {args.num_layers}",
    f"Num hidden dimension: {args.hidden_dim}",
    f"Dropout: {args.dropout}",
    f"Learning rate: {args.lr}",
    f"Early stop: {args.early_stop}",
    sep="\n",
)

def _mask_from_indices(indices, size):
    """Return a bool tensor of length `size` that is True at `indices`."""
    mask = torch.zeros(size, dtype=torch.bool)
    # One vectorised index assignment instead of testing every node id for
    # membership in a Python list (which costs O(size * len(indices))).
    mask[torch.tensor(indices, dtype=torch.long)] = True
    return mask


# cora / pubmed are trained once per seed 0..4, each seed with its own split
# file; every other dataset is trained once with seed 0.
seeds = [0]
if args.dataset in ['cora', 'pubmed']: seeds = list(range(5))
test_accs = []
for seed in seeds:

    if args.dataset in ['cora', 'pubmed']:
        # Rebuild the three node masks from this seed's split file.
        # NOTE(review): for the other datasets no masks are set here, so
        # `data` presumably already carries them from load_data -- confirm.
        train_idx, valid_idx, test_idx = get_split_idx(split_folder, args.dataset, seed)
        data.train_mask = _mask_from_indices(train_idx, num_nodes)
        data.val_mask = _mask_from_indices(valid_idx, num_nodes)
        data.test_mask = _mask_from_indices(test_idx, num_nodes)


    # A fresh model and optimizer per seed so runs do not share weights.
    model = GNN(in_channels=data.x.shape[1],
                hidden_channels=args.hidden_dim,
                out_channels=num_classes,
                num_layers=args.num_layers,
                dropout=args.dropout).to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    trainable_params = sum(p.numel()
                            for p in model.parameters() if p.requires_grad)
    print(f"\nNumber of parameters: {trainable_params}")

    ckpt = f"saved_models/{args.dataset}_{seed}/{args.model}_{args.num_layers}_{args.hidden_dim}_{args.feature}.pt"
    # Early stopping is optional: --early_stop <= 0 disables it.
    stopper = None
    if args.early_stop > 0:
        # EarlyStopping is defined outside this file; creating the checkpoint
        # folder up front guarantees that a write to `ckpt` cannot fail
        # because the directory is missing.
        os.makedirs(os.path.dirname(ckpt), exist_ok=True)
        stopper = EarlyStopping(patience=args.early_stop, path=ckpt)

    pbar = tqdm(range(args.epochs)) if args.verbose else range(args.epochs)
    for epoch in pbar:
        loss = train(model, optimizer, data, loss_func, eval_func)
        # The original code had separate branches for ogbn-* and the other
        # datasets, but both made this exact call; evaluate_ogb is defined in
        # this file and is not used here.
        train_acc, val_acc, test_acc, logits = evaluate(model, data, eval_func)

        # Previously stopper.step() was called even when stopper was None,
        # which raised AttributeError for --early_stop <= 0.
        if stopper is not None:
            es_sign, es_str = stopper.step(val_acc, model, epoch)
        else:
            es_sign, es_str = False, None

        if args.verbose:
            pbar.set_description(f"loss: {loss:.4} | train_acc: {train_acc:.4} | val_acc: {val_acc:.4} | test_acc: {test_acc:.4}")
        if es_sign:
            print(f"Seed: {seed} | Early stopping at epoch {epoch}")
            break